Source code for hysop.operator.plotters

# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from abc import abstractmethod
import collections

from hysop.tools.htypes import to_tuple, check_instance, first_not_None
from hysop.tools.numpywrappers import npw
from hysop.tools.io_utils import IO
from hysop.core.graph.graph import op_apply
from hysop.core.graph.node_requirements import OperatorRequirements
from hysop.core.graph.computational_graph import ComputationalGraphOperator
from hysop.constants import TranspositionState, MemoryOrdering, Backend
from hysop.fields.continuous_field import ScalarField, TensorField
from hysop.parameters.scalar_parameter import ScalarParameter
from hysop.parameters.tensor_parameter import TensorParameter
from hysop.backend.host.host_operator import HostOperatorBase
from hysop.topology.topology_descriptor import TopologyDescriptor


class PlottingOperator(HostOperatorBase):
    """
    Base operator for plotting.

    Holds a matplotlib figure together with its axes. On every call to
    apply() the operator asks two clones of its io parameters (one
    configured with ``update_frequency``, one with ``save_frequency``)
    whether the figure has to be refreshed and/or written to disk.
    Subclasses provide the actual plotting by implementing update().
    """

    @classmethod
    def supports_mpi(cls):
        """Tell the framework that this operator supports MPI."""
        return True

    @classmethod
    def supported_backends(cls):
        """Return the backends accepted by this operator (all of them)."""
        return Backend.all

    def __new__(
        cls,
        name=None,
        dump_dir=None,
        update_frequency=1,
        save_frequency=100,
        axes_shape=(1,),
        figsize=(30, 18),
        visu_rank=None,
        fig=None,
        axes=None,
        force_backend=None,
        **kwds,
    ):
        # Plotting specific arguments are consumed here, only the remaining
        # keywords travel up the __new__ chain.
        return super().__new__(cls, **kwds)

    def __init__(
        self,
        name=None,
        dump_dir=None,
        update_frequency=1,
        save_frequency=100,
        axes_shape=(1,),
        figsize=(30, 18),
        visu_rank=0,
        fig=None,
        axes=None,
        force_backend=None,
        **kwds,
    ):
        """
        Initialize the figure, the axes and the update/save io parameters.

        Parameters
        ----------
        name: str
            Operator name, also used as prefix of the saved image files.
        dump_dir: str, optional
            Folder where images are saved, defaults to IO.default_path().
        update_frequency: int
            Frequency given to the io parameters driving figure updates.
        save_frequency: int
            Frequency given to the io parameters driving figure saves.
        axes_shape: tuple or None
            Subplot grid shape, used to create the figure when none is given
            and to reshape the axes array.
        figsize: tuple
            Figure size forwarded to matplotlib when the figure is created.
        visu_rank: int
            Rank that draws and saves the figure (compared to the rank of
            this operator's mpi parameters).
        fig, axes: optional
            User supplied figure and axes, to be given together.
        force_backend: optional
            Backend forwarded to TopologyDescriptor.build_descriptor() in
            create_topology_descriptors().
        kwds: dict
            Extra keyword arguments forwarded to the base class.

        Raises
        ------
        RuntimeError
            If only one of fig and axes is specified.
        """
        import matplotlib
        import matplotlib.pyplot as plt

        check_instance(name, str)
        check_instance(update_frequency, int, minval=0)
        check_instance(save_frequency, int, minval=0)
        check_instance(axes_shape, tuple, minsize=1, allow_none=True)

        super().__init__(name=name, io_params=True, **kwds)

        # A figure without its axes (or the converse) cannot be used.
        if (fig is None) != (axes is None):
            raise RuntimeError("figure and axes should be specified at the same time.")

        dump_dir = first_not_None(dump_dir, IO.default_path())

        if fig is None:
            fig, axes = plt.subplots(*axes_shape, figsize=figsize)
            # NOTE(review): the extracted source lost its indentation; the
            # two callbacks are assumed to be registered only on figures
            # created here -- confirm against the upstream file.
            fig.canvas.mpl_connect("key_press_event", self.on_key_press)
            fig.canvas.mpl_connect("close_event", self.on_close)

        self.fig = fig
        self.axes = npw.asarray(axes).reshape(axes_shape)
        self.plt = plt

        self.update_frequency = update_frequency
        self.save_frequency = save_frequency
        # 'it' is left as a format field, it is filled in by save().
        self.imgpath = f"{dump_dir}/{name}_{{it:04d}}.png"

        self.should_draw = visu_rank == self.mpi_params.rank
        self.running = True
        self.first_draw = True

        leader_kwds = dict(io_leader=visu_rank, visu_leader=visu_rank, with_last=True)
        self.update_ioparams = self.io_params.clone(
            frequency=self.update_frequency, **leader_kwds
        )
        self.save_ioparams = self.io_params.clone(
            frequency=self.save_frequency, **leader_kwds
        )

        # Extra keywords required to rebuild the topology descriptors.
        td_kwds = {}
        if force_backend is Backend.OPENCL:
            assert "cl_env" in kwds
            td_kwds["cl_env"] = kwds.pop("cl_env")
        self.force_backend = force_backend
        self.td_kwds = td_kwds

    def create_topology_descriptors(self):
        """Rebuild the topology descriptor of every input field."""
        # Here we recreate TopologyDescriptors to allow a forced backend
        # like a OpenCL mapped memory backend or when we do not want
        # to allocate memory for a topology that is just used for I/O.
        for field, handle in self.input_fields.items():
            self.input_fields[field] = TopologyDescriptor.build_descriptor(
                backend=self.force_backend,
                operator=self,
                field=field,
                handle=handle,
                **self.td_kwds,
            )

    def get_field_requirements(self):
        """
        Return the field requirements of the base class, with C-contiguous
        memory ordering and default transposition axes set on each of them.
        """
        requirements = super().get_field_requirements()
        for _is_input, requirement in requirements.iter_requirements():
            if requirement is not None:
                field, _td, req = requirement
                req.memory_order = MemoryOrdering.C_CONTIGUOUS
                req.axes = (TranspositionState[field.dim].default_axes(),)
        return requirements

    def get_node_requirements(self):
        """
        Return the node requirements of the base class after enforcing a
        unique transposition state and memory order, and relaxing the
        unique topology shape and unique ghosts constraints.
        """
        reqs = super().get_node_requirements()
        reqs.enforce_unique_transposition_state = True
        reqs.enforce_unique_memory_order = True
        reqs.enforce_unique_topology_shape = False
        reqs.enforce_unique_ghosts = False
        return reqs

    def draw(self):
        """Redraw and show the figure; no-op when not drawing or stopped."""
        if not (self.should_draw and self.running):
            return
        self.fig.canvas.draw()
        self.fig.show()
        # The very first draw pauses longer than the following ones.
        self.plt.pause(1.0 if self.first_draw else 0.01)
        self.first_draw = False

    @op_apply
    def apply(self, **kwds):
        """Update the figure, then save it, each subject to its io params."""
        self._update(**kwds)
        self._save(**kwds)

    def _update(self, simulation, **kwds):
        """Call update() and draw() when the update io params request it."""
        if not self.update_ioparams.should_dump(simulation=simulation):
            return
        self.update(simulation=simulation, **kwds)
        # NOTE(review): indentation was lost in the extracted source; the
        # redraw is assumed to happen only after an update -- confirm.
        self.draw()

    def _save(self, simulation, **kwds):
        """Call save() when the save io params request it."""
        if not self.save_ioparams.should_dump(simulation=simulation):
            return
        self.save(simulation=simulation, **kwds)

    @abstractmethod
    def update(self, **kwds):
        """Refresh the plotted data, to be implemented by subclasses."""
        pass

    def save(self, simulation, **kwds):
        """Write the figure to the image path of the current iteration."""
        if not self.should_draw:
            return
        target = self.imgpath.format(it=simulation.current_iteration)
        self.fig.savefig(target, dpi=self.fig.dpi, bbox_inches="tight")

    def on_close(self, event):
        """Close event callback: stop drawing."""
        self.running = False

    def on_key_press(self, event):
        """Key press callback: pressing 'q' closes the figure."""
        if event.key == "q":
            self.plt.close(self.fig)
            # NOTE(review): assumed to belong to the 'q' branch, the
            # extracted source does not show the nesting -- confirm.
            self.running = False
class FieldPlotter2D(PlottingOperator):
    """
    Base operator to plot 2D fields at runtime.

    Every scalar field gets its own axes on which it is displayed with
    imshow(). Images are created in discretize() and refreshed in update().
    """

    def __new__(
        cls,
        name,
        fields,
        variables,
        fig_title=None,
        imshow_kwds=None,
        add_colorbars=True,
        symmetric_cbar=False,
        fig=None,
        axes=None,
        shape=None,
        **kwds,
    ):
        # Plotter specific arguments are consumed here, only the remaining
        # keywords travel up the __new__ chain.
        return super().__new__(cls, **kwds)

    def __init__(
        self,
        name,
        fields,
        variables,
        fig_title=None,
        imshow_kwds=None,
        add_colorbars=True,
        symmetric_cbar=False,
        fig=None,
        axes=None,
        shape=None,
        **kwds,
    ):
        """
        Initialize a 2D field plotter.

        Parameters
        ----------
        name: str
            Operator name.
        fields: dict, tuple or list
            Fields to plot, in one of three forms:
              * dict of matplotlib Axes -> ScalarField when both fig and
                axes are given,
              * dict of (i, j) axes index -> ScalarField,
              * tuple or list of ScalarField / TensorField, laid out on a
                single row unless shape is given.
        variables: dict or any
            Maps each field to the value stored for it in the operator
            input fields. A non-dict value is used for every field.
        fig_title: callable, optional
            Function of the simulation returning the figure title.
        imshow_kwds: dict, optional
            Keyword arguments forwarded to imshow(). Defaults are added for
            'interpolation', 'origin' and 'cmap' when missing.
        add_colorbars: bool
            Attach a colorbar to the right of every image.
        symmetric_cbar: bool
            Use color limits that are symmetric around zero.
        fig, axes: optional
            User supplied figure and axes, to be given together.
        shape: tuple, optional
            Overrides the subplot grid shape deduced from fields.
        kwds: dict
            Extra keyword arguments forwarded to PlottingOperator.

        Raises
        ------
        TypeError
            If fields is neither a dict, a tuple nor a list.
        """
        # Work on a copy so that the defaults below never leak into the
        # dictionary owned by the caller.
        imshow_kwds = dict(first_not_None(imshow_kwds, {}))
        imshow_kwds.setdefault("interpolation", "bilinear")
        imshow_kwds.setdefault("origin", "lower")
        imshow_kwds.setdefault("cmap", "bwr")

        def default_figtitle(simulation):
            return f"Fields at t={simulation.time}, iteration={simulation.current_iteration}"

        fig_title = first_not_None(fig_title, default_figtitle)
        assert callable(fig_title), "fig_title has to be a function."

        if not isinstance(variables, dict):
            # Same value for every field.
            variables = collections.defaultdict(lambda v=variables: v)

        if (fig is not None) and (axes is not None):
            # matplotlib is not imported at module level, import it here so
            # that the type check below does not raise a NameError.
            import matplotlib.axes

            check_instance(fields, dict, keys=matplotlib.axes.Axes, values=ScalarField)
            input_fields = {p: variables[p] for p in fields.values()}
            axes_shape = None
        elif isinstance(fields, dict):
            check_instance(fields, dict, keys=tuple, values=ScalarField)
            input_fields = {p: variables[p] for p in fields.values()}
            indices = npw.asarray(list(fields.keys()), dtype=npw.int32)
            assert indices.shape[-1] == 2, indices.shape
            assert (indices >= 0).all(), indices
            # Smallest grid containing every requested (i, j) index.
            axes_shape = tuple(1 + indices.max(axis=0))
            axes_shape = first_not_None(shape, axes_shape)
        elif isinstance(fields, (tuple, list)):
            check_instance(fields, (tuple, list), values=(TensorField, ScalarField))
            input_fields = {p: variables[p] for p in fields}
            naxes = sum(f.nb_components for f in fields)
            axes_shape = first_not_None(shape, (1, naxes))
            # One flat axes index per scalar component.
            fields = dict(zip(range(naxes), sum(map(tuple, fields), ())))
            check_instance(fields, dict, keys=int, values=ScalarField)
        else:
            raise TypeError(fields)

        assert all(field.dim == 2 for field in input_fields), "Fields are not 2D."

        super().__init__(
            name=name,
            input_fields=input_fields,
            axes_shape=axes_shape,
            axes=axes,
            fig=fig,
            **kwds,
        )

        # FigureCanvasBase.set_window_title() was deprecated in Matplotlib
        # 3.4 and removed in 3.6, the figure manager is the supported entry
        # point. A canvas may have no manager, then there is no window.
        manager = getattr(self.fig.canvas, "manager", None)
        if manager is not None:
            manager.set_window_title("HySoP Field Plotter")

        self._plt_cfields = fields
        self._plt_dfields = None
        self._imshow_handles = None
        self._imshow_kwds = imshow_kwds
        self._add_colorbars = add_colorbars
        self._symmetric_cbar = symmetric_cbar
        self._fig_title = fig_title

    def discretize(self):
        """
        Discretize the operator and create one image per scalar field.

        Discrete fields are registered on every rank: the drawing rank
        gets an imshow handle, the others a None handle. The original
        code returned early on non drawing ranks, which left
        _imshow_handles set to None (update() iterates over it) and made
        its own None-handle branch unreachable.
        NOTE(review): this presumes update() is reached on every rank --
        confirm against the MPI behaviour of the io params.
        """
        if self.discretized:
            return
        super().discretize()

        self._plt_dfields = {}
        self._imshow_handles = {}
        for axis_key, input_field in self._plt_cfields.items():
            discrete_field = self.get_input_discrete_field(input_field)
            for scalar_dfield in discrete_field:
                # Keys are flat indices, (i, j) indices or the axes itself.
                if isinstance(axis_key, int):
                    axis = self.axes.ravel()[axis_key]
                elif isinstance(axis_key, tuple):
                    axis = self.axes[axis_key]
                else:
                    axis = axis_key
                self._plt_dfields[axis] = scalar_dfield

                if not self.should_draw:
                    # Nothing is rendered on this rank.
                    self._imshow_handles[scalar_dfield] = None
                    continue

                axis.set_title(scalar_dfield.name)
                axis.set_xlabel("x")
                axis.set_ylabel("y")

                # Placeholder image, real values are set by update().
                data = npw.zeros(
                    shape=discrete_field.mesh.grid_resolution,
                    dtype=discrete_field.dtype,
                )
                box = scalar_dfield.domain
                extent = (box.origin[1], box.end[1], box.origin[0], box.end[0])
                img = axis.imshow(data, extent=extent, **self._imshow_kwds)
                self._imshow_handles[scalar_dfield] = img

                if self._add_colorbars:
                    from mpl_toolkits.axes_grid1 import make_axes_locatable

                    divider = make_axes_locatable(axis)
                    cax = divider.append_axes("right", size="5%", pad=0.05)
                    self.fig.colorbar(img, cax=cax, orientation="vertical")

    def update(self, simulation, **kwds):
        """Set the figure title and refresh every image with current data."""
        self.fig.suptitle(self._fig_title(simulation))
        leader = self.update_ioparams.io_leader
        for dfield, handle in self._imshow_handles.items():
            data = dfield.collect_data(leader=leader)
            self._update_imshow_handle(handle, data)

    def _update_imshow_handle(self, handle, data):
        """Push data into an image and rescale its color limits."""
        if handle is None:
            # No image is attached to this discrete field.
            return
        handle.set_data(data)
        dmin, dmax = data.min(), data.max()
        if self._symmetric_cbar:
            dinf = max(abs(dmin), abs(dmax))
            handle.set_clim(-dinf, +dinf)
        else:
            handle.set_clim(dmin, dmax)
class ParameterPlotter(PlottingOperator):
    """
    Base operator to plot parameters during runtime.

    Every scalar parameter (or scalar view of a tensor parameter) is drawn
    as a line of its sampled values against the simulation time.
    """

    def __new__(
        cls, name, parameters, alloc_size=128, fig=None, axes=None, shape=None, **kwds
    ):
        # Mirrors __init__, like the sibling plotting operators do, so that
        # plotter specific arguments are not sent up the __new__ chain.
        return super().__new__(cls, **kwds)

    def __init__(
        self, name, parameters, alloc_size=128, fig=None, axes=None, shape=None, **kwds
    ):
        """
        Initialize a parameter plotter.

        Parameters
        ----------
        name: str
            Operator name.
        parameters:
            Parameters to plot, in one of these forms:
              * dict of matplotlib Axes -> dict of label -> ScalarParameter
                when both fig and axes are given,
              * a single TensorParameter (plotted at position 0),
              * a list or tuple, one axes position per entry,
              * a dict of axes position (int or tuple) -> TensorParameter,
                list, tuple or dict of label -> TensorParameter.
            Non scalar parameters are expanded into one line per index.
        alloc_size: int
            Initial number of samples of the time and value buffers, which
            grow on demand.
        fig, axes: optional
            User supplied figure and axes, to be given together.
        shape: optional
            Accepted but not used by this constructor.
        kwds: dict
            Extra keyword arguments forwarded to PlottingOperator.

        Raises
        ------
        TypeError
            If parameters, or one of its values, has an unsupported type.
        """
        # Set of parameter objects handed to the base class. The original
        # code mixed set and dict operations on it, which raised a
        # TypeError on item assignment and stored names instead of
        # parameters.
        input_params = set()

        if (fig is not None) and (axes is not None):
            import matplotlib.axes

            custom_axes = True
            axes_shape = None
            check_instance(parameters, dict, keys=matplotlib.axes.Axes, values=dict)
            for params in parameters.values():
                check_instance(params, dict, keys=str, values=ScalarParameter)
                input_params.update(params.values())
        else:
            custom_axes = False

            # Normalize the input to a dict of position -> parameters.
            _parameters = {}
            if isinstance(parameters, TensorParameter):
                _parameters[0] = parameters
            elif isinstance(parameters, (list, tuple)):
                for i, p in enumerate(parameters):
                    _parameters[i] = p
            elif isinstance(parameters, dict):
                _parameters = parameters.copy()
            else:
                raise TypeError(type(parameters))
            check_instance(
                _parameters,
                dict,
                keys=(int, tuple, list),
                values=(TensorParameter, list, tuple, dict),
            )

            layout = {}
            axes_shape = (1,) * 2
            for pos, params in _parameters.items():
                # Positions are left padded with zeros to two indices.
                pos = to_tuple(pos)
                pos = (2 - len(pos)) * (0,) + pos
                check_instance(pos, tuple, values=int)
                # Grow the grid so that it contains this position.
                axes_shape = tuple(max(p0, p1 + 1) for (p0, p1) in zip(axes_shape, pos))

                # Normalize the parameters to a dict of label -> parameter.
                if isinstance(params, dict):
                    pass
                elif isinstance(params, TensorParameter):
                    params = {params.name: params}
                elif isinstance(params, (list, tuple)):
                    params = {p.name: p for p in params}
                else:
                    raise TypeError(type(params))
                check_instance(params, dict, keys=str, values=TensorParameter)
                input_params.update(params.values())

                # One plotted line per scalar parameter or scalar view.
                expanded = {}
                for pname, p in params.items():
                    if isinstance(p, ScalarParameter):
                        expanded[pname] = p
                    else:
                        for idx in npw.ndindex(*p.shape):
                            expanded[pname + f"_{idx}"] = p.view(idx)
                layout[pos] = expanded
            parameters = layout

        super().__init__(
            name=name,
            input_params=input_params,
            axes_shape=axes_shape,
            axes=axes,
            fig=fig,
            **kwds,
        )

        # Required by get_axes() below.
        self.custom_axes = custom_axes

        data = {}
        lines = {}
        times = npw.empty(shape=(alloc_size,), dtype=npw.float32)
        for pos, params in parameters.items():
            target_axes = self.get_axes(pos)
            params_data = {}
            params_lines = {}
            for pname, p in params.items():
                params_data[p] = npw.empty(shape=(alloc_size,), dtype=p.dtype)
                params_lines[p] = target_axes.plot([], [], label=pname)[0]
            data[pos] = params_data
            lines[pos] = params_lines

        # FigureCanvasBase.set_window_title() was deprecated in Matplotlib
        # 3.4 and removed in 3.6, the figure manager is the supported entry
        # point. A canvas may have no manager, then there is no window.
        manager = getattr(self.fig.canvas, "manager", None)
        if manager is not None:
            manager.set_window_title("HySoP Parameter Plotter")

        self.parameters = parameters
        self.times = times
        self.data = data
        self.lines = lines
        self.alloc_size = alloc_size
        self.counter = 0

    def get_axes(self, pos):
        """
        Return the axes a position refers to: the position itself for user
        supplied axes, the matching entry of the axes array otherwise.
        """
        if self.custom_axes:
            return pos
        return self.axes[pos]

    def __getitem__(self, i):
        """Index the user supplied axes, or the flattened axes array."""
        if self.custom_axes:
            return self.axes[i]
        return self.axes.flatten()[i]

    @staticmethod
    def _grown(buffer, size):
        """Return a buffer of the given size starting with the old samples."""
        grown = npw.empty(shape=(size,), dtype=buffer.dtype)
        grown[: buffer.size] = buffer
        return grown

    def update(self, simulation, **kwds):
        """
        Sample the simulation time and every parameter, then refresh the
        lines with all the samples recorded so far.
        """
        nsamples = self.counter + 1

        # expand memory if required
        if nsamples > self.times.size:
            # Doubling alone would never grow an empty buffer.
            new_size = max(2 * self.times.size, nsamples)
            self.times = self._grown(self.times, new_size)
            for params in self.data.values():
                for p, pdata in list(params.items()):
                    params[p] = self._grown(pdata, new_size)

        times, data, lines = self.times, self.data, self.lines
        times[self.counter] = simulation.t()
        for pos, params in self.parameters.items():
            for p in params.values():
                samples = data[pos][p]
                samples[self.counter] = p()
                # The slice includes the sample just written; the original
                # stopped one short so the newest value was never shown.
                lines[pos][p].set_xdata(times[:nsamples])
                lines[pos][p].set_ydata(samples[:nsamples])
        self.counter = nsamples